tensor([[[0.7800, 0.0000],
         [0.0000, 0.8900]],

        [[0.0800, 0.4600],
         [0.4500, 0.0200]],

        [[0.1400, 0.5400],
         [0.5500, 0.0900]]], grad_fn=<RoundBackward1>)